#
# Copyright (c) nexB Inc. and others. All rights reserved.
# VulnerableCode is a trademark of nexB Inc.
# SPDX-License-Identifier: Apache-2.0
# See http://www.apache.org/licenses/LICENSE-2.0 for the license text.
# See https://github.com/nexB/vulnerablecode for support or download.
# See https://aboutcode.org for more information about nexB OSS projects.
#

from typing import Iterable
from urllib.parse import urljoin

from django.db.models import Q
from django.db.models.query import QuerySet

from vulnerabilities.importer import AdvisoryData
from vulnerabilities.importers.nvd import NVDImporter
from vulnerabilities.improver import Improver
from vulnerabilities.improver import Inference
from vulnerabilities.models import Advisory
from vulnerabilities.models import Alias
from vulnerabilities.models import Vulnerability
from vulnerabilities.models import VulnerabilityChangeLog
from vulnerabilities.models import VulnerabilityStatusType
from vulnerabilities.utils import fetch_response
from vulnerabilities.utils import get_item

MITRE_API_URL = "https://cveawg.mitre.org/api/cve/"


class VulnerabilityStatusImprover(Improver):
    """
    Update vulnerability with NVD statues
    """

    improver_name = "NVD CVE Status Improver"

    @property
    def interesting_advisories(self) -> QuerySet:
        return (
            Advisory.objects.filter(Q(created_by=NVDImporter.qualified_name))
            .distinct("aliases")
            .paginated()
        )

    def get_inferences(self, advisory_data: AdvisoryData) -> Iterable[Inference]:
        """
        This is a work-around until we have new style importer and improver
        and this get_inferences function updates the vulnerability status directly
        # TODO: Replace this with new style improvers
        """
        if not advisory_data:
            return []
        aliases = advisory_data.aliases
        # NVD Importer only has one alias in it and this a CVE
        assert len(aliases) == 1
        cve_id = aliases[0]
        if not cve_id.startswith("CVE"):
            return []

        alias = Alias.objects.get(alias=cve_id)
        vulnerabilities = Vulnerability.objects.filter(aliases__alias=alias).distinct()

        for vuln in vulnerabilities:
            url = urljoin(MITRE_API_URL, cve_id)
            current_status = get_status_from_api(url=url)
            if not current_status:
                current_status = VulnerabilityStatusType.PUBLISHED
            old_status = vuln.status
            vuln.status = current_status
            if current_status != old_status:
                VulnerabilityChangeLog.log_improve(
                    improver=VulnerabilityStatusImprover.improver_name,
                    vulnerability=vuln,
                    source_url=url,
                )
            vuln.save()
        return []


def get_status_from_api(url):
    """
    Return the CVE status from the MITRE API
    """
    try:
        response = fetch_response(url=url)
    except Exception as e:
        return
    response = response.json()
    cve_state = get_item(response, "cveMetadata", "state") or None
    tags = get_item(response, "containers", "cna", "tags") or []
    if "disputed" in tags:
        return VulnerabilityStatusType.DISPUTED
    if cve_state and cve_state == "REJECTED":
        return VulnerabilityStatusType.INVALID
    return VulnerabilityStatusType.PUBLISHED
